"""
Phase 4b: Probing the Post-GFC Structural Break
==================================================
The rolling window results from Phase 4 showed a dramatic break: Z polynomials
strongly predict Japanification pre-GFC (all p<0.01) but reverse sign post-GFC.

Three competing stories:
  A. Monetary policy suppression (QE/ZIRP masked the demographic signal)
  B. Mechanical — rate component lost variation at the zero lower bound
  C. Something genuinely changed in the demographics-macro relationship

Tests:
  1. Component-by-component rolling windows (which channel broke?)
  2. QE vs non-QE country split
  3. 2-component index rolling windows (remove rate channel entirely)
  4. Policy rate proximity to ZLB as control
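  5. Fiscal impulse control (change in structural fiscal balance, computed
     as the within-country first difference of fiscal_bal_gdp)
  6. Post-GFC sub-period split (2009-2015 vs 2016-2024)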

Input:  japanification/data/processed/japan_panel_indexed.csv
Output: japanification/output/tables/phase4b_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
PROCESSED_DIR = JAPAN_DIR / "data" / "processed"
TABLE_DIR = JAPAN_DIR / "output" / "tables"

sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_quick(y, X, entity_ids, time_ids):
    """Estimate a PanelGLS model without printing a summary.

    All four arguments are forwarded unchanged, in this order, to
    ``PanelGLS.fit``.  The fitted estimator object itself is returned so the
    caller can read whatever attributes it needs from it.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)
    return estimator


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate a PanelGLS model, print its summary, and tabulate the result.

    The model is fitted first; then ``label`` is printed as a heading and the
    estimator's own ``summary`` is invoked with ``feature_names``.

    Returns:
        A ``(model, table)`` pair, where ``table`` is the output of
        ``model.to_dataframe(feature_names=...)`` with an extra ``'model'``
        column holding ``label`` on every row.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Heading goes out after the fit, directly above the estimator's summary.
    print(f"\n  {label}")
    estimator.summary(feature_names=feature_names)

    table = estimator.to_dataframe(feature_names=feature_names)
    table['model'] = label
    return estimator, table


# ISO3 codes of countries treated as having conducted large-scale asset
# purchases post-2008.  Membership is time-invariant: a country is either in
# the QE group for the whole sample or not at all.
# NOTE(review): some euro-area members adopted the euro after 2008 — confirm
# that classifying them as QE countries for the full period is intended.
QE_COUNTRIES = (
    # Individual central banks
    ['USA', 'GBR', 'JPN', 'SWE', 'CHE']
    # Eurozone members
    + ['AUT', 'BEL', 'CYP', 'EST', 'FIN', 'FRA', 'DEU', 'GRC', 'IRL']
    + ['ITA', 'LVA', 'LTU', 'LUX', 'MLT', 'NLD', 'PRT', 'SVK', 'SVN', 'ESP']
)


# Rolling-window design shared by Tests 1 and 3: each window covers
# start_year .. start_year + 14 inclusive, for start years 1990..2010, and is
# skipped when fewer than 100 complete rows remain after dropping NaNs.
_WINDOW_STARTS = range(1990, 2011)
_WINDOW_SPAN = 14
_MIN_WINDOW_OBS = 100
# Windows starting in or before 1995 are "early"; in or after 2005 are "late".
_EARLY_MAX_START = 1995
_LATE_MIN_START = 2005


def _sig_stars(pval):
    """Return '***', '**', '*' or '' for p < 0.01, < 0.05, < 0.1, otherwise."""
    if pval < 0.01:
        return '***'
    if pval < 0.05:
        return '**'
    if pval < 0.1:
        return '*'
    return ''


def _rolling_window_fits(df, targets, base_vars, demo_vars, tag_col,
                         include_countries=False):
    """Fit one model per (target, rolling window) and tabulate the Z terms.

    Args:
        df: Panel with a 'year' column, an 'iso3' column, every column in
            ``base_vars`` and every dependent variable named in ``targets``.
        targets: Mapping ``tag -> dependent-variable column``; the tag is
            stored in the ``tag_col`` column of the output.
        base_vars: Regressor columns.  ``demo_vars`` must be its leading
            entries, because coefficients are read back by position.
        demo_vars: Regressors whose coefficient and p-value are recorded.
        tag_col: Name of the output column holding the tag.
        include_countries: Also record ``model.n_countries`` as 'N_countries'.

    Returns:
        DataFrame with one row per fitted window.  The columns are fixed up
        front, so the frame keeps its full schema even when every window was
        skipped for having too few observations.
    """
    columns = [tag_col, 'window_start', 'window_end', 'N']
    if include_countries:
        columns.append('N_countries')
    columns.append('R_squared')
    for zv in demo_vars:
        columns.extend([f'{zv}_coef', f'{zv}_pval'])

    rows = []
    for tag, dep in targets.items():
        for start_year in _WINDOW_STARTS:
            end_year = start_year + _WINDOW_SPAN
            win = df[(df['year'] >= start_year) & (df['year'] <= end_year)]
            win_est = win.dropna(subset=[dep] + base_vars)
            if len(win_est) < _MIN_WINDOW_OBS:
                continue

            model = fit_quick(
                win_est[dep].values, win_est[base_vars].values,
                win_est['iso3'].values, win_est['year'].values
            )

            row = {
                tag_col: tag,
                'window_start': start_year,
                'window_end': end_year,
                'N': model.n_obs,
                'R_squared': model.r_squared,
            }
            if include_countries:
                row['N_countries'] = model.n_countries
            for i, zv in enumerate(demo_vars):
                row[f'{zv}_coef'] = model.beta[i]
                row[f'{zv}_pval'] = model.pvalues[i]
            rows.append(row)

    return pd.DataFrame(rows, columns=columns)


def _early_late_z1(windows):
    """Summarise Z_1 over early vs late windows.

    Returns ``(early_coef, late_coef, early_sig_pct, late_sig_pct)`` — mean
    Z_1 coefficient and share of windows with p < 0.1 (in percent) — or
    ``None`` when either the early or the late group has no windows.
    """
    early = windows[windows['window_start'] <= _EARLY_MAX_START]
    late = windows[windows['window_start'] >= _LATE_MIN_START]
    if early.empty or late.empty:
        return None
    return (
        early['Z_1_coef'].mean(),
        late['Z_1_coef'].mean(),
        (early['Z_1_pval'] < 0.1).mean() * 100,
        (late['Z_1_pval'] < 0.1).mean() * 100,
    )


def main():
    """Run the Phase 4b structural-break diagnostics and write result tables.

    Reads ``japan_panel_indexed.csv`` from ``PROCESSED_DIR`` and runs, in
    order: component-wise rolling windows (Test 1), a QE / non-QE split
    (Test 2), 2- vs 3-component index rolling windows (Test 3), a ZLB
    interaction model (Test 4), a fiscal-impulse control (Test 5) and a
    post-GFC sub-period split.  Progress and diagnostics are printed.

    Files written to ``TABLE_DIR`` (created if missing):
        phase4b_component_rolling.csv, phase4b_2c_rolling.csv,
        phase4b_structural_break_results.csv, phase4b_z1_summary.csv

    The last two are skipped when no model could be estimated at all.

    The code relies on the fitted model exposing ``n_obs``, ``n_countries``,
    ``r_squared``, ``beta`` and ``pvalues``, and on the result table returned
    by ``fit_and_report`` having 'variable', 'coefficient', 'std_error' and
    'p_value' columns.
    """
    print("=" * 70)
    print("PHASE 4b: Probing the Post-GFC Structural Break")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Make sure the output directory exists before any estimation starts, so
    # a missing folder cannot fail the run after the fits have been computed.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    base_vars = demo_vars + controls
    # Dependent variable for Tests 2, 4, 5 and the sub-period split.
    dep_var = 'japan_index_2c'

    all_results = []

    # =================================================================
    # TEST 1: Component-by-component rolling windows
    # =================================================================
    # Which component's relationship with demographics broke?
    print(f"\n{'=' * 70}")
    print("  TEST 1: Component-by-Component Rolling Windows")
    print("=" * 70)

    components = {
        'growth': 'rgdp_growth',
        'inflation': 'inflation_japan',
        'rate': 'rate_japan',
    }

    comp_windows = _rolling_window_fits(
        df, components, base_vars, demo_vars,
        tag_col='component', include_countries=True,
    )
    comp_windows.to_csv(TABLE_DIR / "phase4b_component_rolling.csv", index=False)

    # Print summary: Z_1 coefficient by component, every third window
    print("\n  Z_1 coefficient by component across rolling windows:")
    print(f"  {'Window':<12} {'Growth':>10} {'p':>6} {'Inflation':>10} {'p':>6} {'Rate':>10} {'p':>6}")
    print("  " + "-" * 62)
    for start in range(1990, 2011, 3):
        end = start + _WINDOW_SPAN
        row_parts = [f"  {start}-{end}"]
        for comp in components:
            sub = comp_windows[(comp_windows['component'] == comp) &
                               (comp_windows['window_start'] == start)]
            if len(sub) > 0:
                coef = sub['Z_1_coef'].values[0]
                pval = sub['Z_1_pval'].values[0]
                row_parts.append(f"{coef:>10.3f} {pval:>5.3f}{_sig_stars(pval)}")
            else:
                row_parts.append(f"{'---':>10} {'':>6}")
        print("  ".join(row_parts))

    # Identify which component drove the break
    print("\n  DIAGNOSIS: Which component's Z relationship broke?")
    for comp in components:
        summary = _early_late_z1(comp_windows[comp_windows['component'] == comp])
        if summary is None:
            continue
        early_coef, late_coef, early_sig, late_sig = summary
        print(f"    {comp:>10}: early Z₁={early_coef:+.3f} (sig {early_sig:.0f}%) → "
              f"late Z₁={late_coef:+.3f} (sig {late_sig:.0f}%)")

    # =================================================================
    # TEST 2: QE vs non-QE country split
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 2: QE vs Non-QE Country Split")
    print("=" * 70)

    is_qe = df['iso3'].isin(QE_COUNTRIES)
    # Result tables keyed by (group label, 'pre' | 'post') for the summary.
    qe_fits = {}

    for qe_label, qe_mask in [('QE countries', is_qe),
                              ('Non-QE countries', ~is_qe)]:
        sub = df[qe_mask]
        print(f"\n  --- {qe_label} ({sub['iso3'].nunique()} countries) ---")

        for period_key, period_label, period_mask in [
            ('pre', 'Pre-GFC (1990-2007)', sub['year'] <= 2007),
            ('post', 'Post-GFC (2009-2024)', sub['year'] >= 2009),
        ]:
            est = sub[period_mask].dropna(subset=[dep_var] + base_vars)
            if len(est) < 50:
                print(f"    {period_label}: insufficient obs ({len(est)})")
                continue

            _, r = fit_and_report(
                est[dep_var].values, est[base_vars].values,
                est['iso3'].values, est['year'].values,
                base_vars, f"{qe_label}: {period_label}"
            )
            all_results.append(r)
            qe_fits[(qe_label, period_key)] = r

    # Summary comparison (only for groups where both periods were estimated)
    print("\n  QE vs Non-QE Break Summary (Z₁ coefficient):")
    print(f"  {'Group':<25} {'Pre-GFC':>10} {'Post-GFC':>10} {'Change':>10}")
    print("  " + "-" * 55)
    for group_label in ['QE countries', 'Non-QE countries']:
        pre = qe_fits.get((group_label, 'pre'))
        post = qe_fits.get((group_label, 'post'))
        if pre is None or post is None:
            continue
        pre_z1 = pre.loc[pre['variable'] == 'Z_1', 'coefficient'].values[0]
        post_z1 = post.loc[post['variable'] == 'Z_1', 'coefficient'].values[0]
        print(f"  {group_label:<25} {pre_z1:>10.3f} {post_z1:>10.3f} {post_z1 - pre_z1:>10.3f}")

    # =================================================================
    # TEST 3: 2-component index rolling windows (no rate channel)
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 3: 2-Component Index Rolling Windows (Rate Channel Removed)")
    print("=" * 70)

    idx_2c_df = _rolling_window_fits(
        df, {'2-component': 'japan_index_2c'}, base_vars, demo_vars,
        tag_col='index',
    )
    idx_2c_df.to_csv(TABLE_DIR / "phase4b_2c_rolling.csv", index=False)

    # 3-component windows are computed for the printed comparison only.
    idx_3c_df = _rolling_window_fits(
        df, {'3-component': 'japan_index_3c'}, base_vars, demo_vars,
        tag_col='index',
    )

    # Print comparison, every second window
    print("\n  Z₁ coefficient: 2-component vs 3-component index")
    print(f"  {'Window':<12} {'2c coef':>10} {'p':>6} {'3c coef':>10} {'p':>6}")
    print("  " + "-" * 50)
    for start in range(1990, 2011, 2):
        end = start + _WINDOW_SPAN
        parts = [f"  {start}-{end}"]
        for frame in (idx_2c_df, idx_3c_df):
            hit = frame[frame['window_start'] == start]
            if len(hit) > 0:
                parts.append(f"{hit['Z_1_coef'].values[0]:>10.3f} {hit['Z_1_pval'].values[0]:>5.3f}")
            else:
                parts.append(f"{'---':>10} {'':>6}")
        print("  ".join(parts))

    # Key diagnostic
    print("\n  DIAGNOSIS (2-component, no rates):")
    summary_2c = _early_late_z1(idx_2c_df)
    if summary_2c is None or np.isnan(summary_2c[0]) or np.isnan(summary_2c[1]):
        # Without both an early and a late estimate every comparison below
        # would be False and the final branch would wrongly announce that the
        # break is gone; report the lack of evidence instead.
        print("    → Not enough early/late windows to compare; no verdict on Story B")
    else:
        early_2c, late_2c, early_2c_sig, late_2c_sig = summary_2c
        print(f"    Early windows Z₁: {early_2c:+.3f} (sig {early_2c_sig:.0f}%)")
        print(f"    Late windows Z₁:  {late_2c:+.3f} (sig {late_2c_sig:.0f}%)")
        if abs(late_2c - early_2c) > 1.0 and late_2c * early_2c < 0:
            print("    → Break PERSISTS even without rate channel (Story B rejected)")
        elif late_2c_sig < 30 and early_2c_sig > 60:
            print("    → Break weakened but present (partially Story B)")
        else:
            print("    → Break GONE without rate channel (Story B supported — it was the rates)")

    # =================================================================
    # TEST 4: Policy rate proximity to ZLB
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 4: Zero Lower Bound Proximity")
    print("=" * 70)

    # Construct ZLB proximity: 1 when the rate is observed and <= 1.0, else 0.
    # The rate column is 'policy_rate' if the panel has it, else 'rate_japan'.
    # NOTE(review): rows with a missing rate are coded 0 ("not near ZLB") and
    # stay in the estimation sample; the choice of rate column is per-file,
    # not per-row.  Confirm both are intended.
    rate_col = 'policy_rate' if 'policy_rate' in df.columns else 'rate_japan'
    df['near_zlb'] = 0
    df.loc[df[rate_col].notna() & (df[rate_col] <= 1.0), 'near_zlb'] = 1

    print(f"  Countries near ZLB (rate ≤ 1%): "
          f"{df['near_zlb'].sum()} obs, "
          f"{df.loc[df['near_zlb'] == 1, 'iso3'].nunique()} countries")

    # Interact Z with ZLB proximity
    for zv in demo_vars:
        df[f'{zv}_x_zlb'] = df[zv] * df['near_zlb']

    zlb_int_vars = [f'{zv}_x_zlb' for zv in demo_vars]
    zlb_vars = base_vars + ['near_zlb'] + zlb_int_vars

    # Post-GFC only (when ZLB became relevant)
    post_gfc = df[df['year'] >= 2009].dropna(subset=[dep_var] + zlb_vars)
    if len(post_gfc) >= 100:
        m_zlb, r_zlb = fit_and_report(
            post_gfc[dep_var].values, post_gfc[zlb_vars].values,
            post_gfc['iso3'].values, post_gfc['year'].values,
            zlb_vars, "Post-GFC with ZLB interaction"
        )
        all_results.append(r_zlb)

        # Interpret: Z effect away from ZLB vs at ZLB
        z1_base = m_zlb.beta[zlb_vars.index('Z_1')]
        z1_zlb = m_zlb.beta[zlb_vars.index('Z_1_x_zlb')]
        print(f"\n  Z₁ effect away from ZLB: {z1_base:.3f}")
        print(f"  Z₁ effect at ZLB:        {z1_base + z1_zlb:.3f} (interaction: {z1_zlb:.3f})")

    # Full sample with ZLB interaction
    full_zlb = df.dropna(subset=[dep_var] + zlb_vars)
    if len(full_zlb) >= 200:
        _, r_zlb_full = fit_and_report(
            full_zlb[dep_var].values, full_zlb[zlb_vars].values,
            full_zlb['iso3'].values, full_zlb['year'].values,
            zlb_vars, "Full sample with ZLB interaction"
        )
        all_results.append(r_zlb_full)

    # =================================================================
    # TEST 5: Fiscal impulse control
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  TEST 5: Fiscal Impulse Control")
    print("=" * 70)

    # Fiscal impulse = change in fiscal balance (negative = stimulus).
    # Sorting first makes diff() compare each row with the previous available
    # year of the same country.
    df = df.sort_values(['iso3', 'year'])
    df['fiscal_impulse'] = df.groupby('iso3')['fiscal_bal_gdp'].diff()

    fisc_vars = base_vars + ['fiscal_impulse']
    fiscal_fits = {}

    # Pre/post with fiscal impulse
    for label, mask in [('Pre-GFC + fiscal impulse', df['year'] <= 2007),
                        ('Post-GFC + fiscal impulse', df['year'] >= 2009)]:
        est = df[mask].dropna(subset=[dep_var] + fisc_vars)
        if len(est) >= 100:
            _, r_fi = fit_and_report(
                est[dep_var].values, est[fisc_vars].values,
                est['iso3'].values, est['year'].values,
                fisc_vars, label
            )
            all_results.append(r_fi)
            fiscal_fits[label] = r_fi

    # Post-GFC: does controlling for fiscal impulse restore the Z signal?
    print("\n  DIAGNOSIS: Does fiscal impulse explain the post-GFC break?")
    post_fit = fiscal_fits.get('Post-GFC + fiscal impulse')
    if post_fit is not None:
        z1_row = post_fit[post_fit['variable'] == 'Z_1']
        if len(z1_row) > 0:
            z1_coef = z1_row['coefficient'].values[0]
            z1_p = z1_row['p_value'].values[0]
            fi_row = post_fit[post_fit['variable'] == 'fiscal_impulse']
            fi_coef = fi_row['coefficient'].values[0] if len(fi_row) > 0 else np.nan
            fi_p = fi_row['p_value'].values[0] if len(fi_row) > 0 else np.nan
            print(f"    Z₁ with fiscal control: {z1_coef:.3f} (p={z1_p:.3f})")
            print(f"    Fiscal impulse:         {fi_coef:.4f} (p={fi_p:.3f})")

    # =================================================================
    # ADDITIONAL: Post-GFC sub-periods (2009-2015 vs 2016-2024)
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  ADDITIONAL: Post-GFC Sub-Periods")
    print("=" * 70)

    for label, (first_year, last_year) in [
        ('Post-GFC early (2009-2015)', (2009, 2015)),
        ('Post-GFC late (2016-2024)', (2016, 2024)),
    ]:
        sub = df[(df['year'] >= first_year) & (df['year'] <= last_year)]
        est = sub.dropna(subset=[dep_var] + base_vars)
        if len(est) >= 100:
            _, r_sub = fit_and_report(
                est[dep_var].values, est[base_vars].values,
                est['iso3'].values, est['year'].values,
                base_vars, label
            )
            all_results.append(r_sub)

    # =================================================================
    # SYNTHESIS
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  SYNTHESIS: What Explains the Break?")
    print("=" * 70)

    if not all_results:
        # pd.concat raises ValueError on an empty list; there is nothing to
        # summarise when every sample fell below its minimum size.
        print("\n  No models were estimated; skipping synthesis tables.")
        return

    results_df = pd.concat(all_results, ignore_index=True)

    # Z_1 summary across all models
    z1_summary = (results_df[results_df['variable'] == 'Z_1']
                  [['model', 'coefficient', 'std_error', 'p_value']]
                  .sort_values('model'))
    print("\n  Z₁ coefficient across all structural break tests:")
    for _, row in z1_summary.iterrows():
        sig = _sig_stars(row['p_value'])
        print(f"    {row['model']:<45} {row['coefficient']:>8.3f} (p={row['p_value']:.3f}){sig}")

    # Save
    results_df.to_csv(TABLE_DIR / "phase4b_structural_break_results.csv", index=False)
    z1_summary.to_csv(TABLE_DIR / "phase4b_z1_summary.csv", index=False)
    print(f"\nSaved to {TABLE_DIR / 'phase4b_*.csv'}")


# Run the full analysis only when this file is executed as a script, so that
# importing the module does not trigger the estimation or write any files.
if __name__ == "__main__":
    main()
